#  This file is part of Pynguin.
#
#  SPDX-FileCopyrightText: 2019–2025 Pynguin Contributors
#
#  SPDX-License-Identifier: MIT
#
"""Provides some mutation related utilities."""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar

import pynguin.configuration as config
from pynguin.utils import randomness

if TYPE_CHECKING:
    from collections.abc import Callable


T = TypeVar("T")


def alpha_exponent_insertion(
    elements: list[T],
    value_supplier: Callable[[], T | None],
    alpha: float = 0.5,
    exponent: int = 0,
) -> bool:
    """Inserts supplied values into a list with geometrically decaying probability.

    Before every insertion a random float is drawn; the loop goes on only while
    that float is <= alpha^exponent. The exponent grows by one per round, so each
    further insertion becomes less likely. Every value is placed at a freshly
    drawn random position of the (mutated in place) list.

    Args:
        elements: the list that receives the new elements; modified in place.
        value_supplier: produces the next value to insert, or None once it has
            nothing left to offer, which ends the insertion immediately.
        alpha: base of the decaying probability; must lie strictly in (0, 1).
        exponent: initial exponent applied to alpha.

    Returns:
        True, iff at least one element was inserted.
    """
    assert 0 < alpha < 1
    inserted_any = False
    # An empty list only has one possible insertion point, so no draw is needed.
    target = 0
    while randomness.next_float() <= alpha**exponent:
        exponent += 1
        if elements:
            # Draw a new position for every single insertion.
            target = randomness.next_int(0, len(elements) + 1)

        candidate = value_supplier()
        if candidate is None:
            # The supplier has run dry; keep what was inserted so far.
            break

        elements.insert(target, candidate)
        inserted_any = True

    return inserted_any


def multiple_alpha_exponent_insertion(
    elements: list,
    shape: list,
    insertion_supplier: Callable,
    alpha: float = 0.5,
    exponent: int = 0,
) -> tuple[list, bool]:
    """Provides an alpha-exponent insertion algorithm for a nested list with a shape.

    A random axis of ``shape`` is chosen. As long as a drawn random float is
    <= alpha^exponent, another insertion position along that axis is planned and
    the exponent is increased. The number of planned insertions is capped so that
    the axis does not grow beyond the configured ``pynguinml.max_shape_dim``.

    ``shape`` itself is not updated by this function.

    Args:
        elements: the elements into which new elements should be inserted.
        shape: the shape of the elements.
        insertion_supplier: supplies the elements that are inserted.
        alpha: the used alpha value; must lie strictly in (0, 1).
        exponent: start value of the exponent.

    Returns:
        A tuple of the resulting elements and a flag. The flag is True whenever at
        least one insertion position was planned; otherwise the given ``elements``
        are returned as they are together with False.
    """
    assert 0 < alpha < 1

    if len(shape) == 0:
        # Without any axis there is nothing to grow along; drawing an axis from
        # an empty range would fail, so report "unchanged" instead.
        return elements, False

    chosen_axis = randomness.next_int(0, len(shape))
    axis_size = shape[chosen_axis]

    max_shape_dim = config.configuration.pynguinml.max_shape_dim
    if axis_size >= max_shape_dim:
        # The chosen axis is already at (or above) the allowed maximum.
        return elements, False

    free_slots = max_shape_dim - axis_size

    insertion_indices = []
    while randomness.next_float() <= pow(alpha, exponent):
        pos = randomness.next_int(0, axis_size + 1)
        insertion_indices.append(pos)
        exponent += 1

    # trim if the number of planned insertions exceeds the free slots
    if len(insertion_indices) > free_slots:
        insertion_indices = insertion_indices[:free_slots]

    if not insertion_indices:
        return elements, False

    # sort in ascending order for consistent insertion
    insertion_indices.sort()

    return _insert_indices_at_axis(
        elements, chosen_axis, insertion_indices, insertion_supplier
    ), True


def _insert_indices_at_axis(
    array: list, axis: int, insertion_indices: list, insertion_supplier: Callable
) -> list:
    """Recursively inserts new elements at the specified indices along the given axis.

    - When axis == 0, we are at the level where insertion occurs: we insert new
      elements into the list.
    - For deeper axes (axis > 0), apply recursively to each subarray.
    """

    def _create_subarray_like(template):
        if isinstance(template, list):
            return [_create_subarray_like(template[0]) for _ in range(len(template))]
        return insertion_supplier()

    if axis == 0:
        new_array = []
        current_index = 0
        insertion_iter = iter(insertion_indices)
        next_insertion = next(insertion_iter, None)
        for elem in array:
            while next_insertion is not None and current_index == next_insertion:
                new_elem = _create_subarray_like(elem) if array else insertion_supplier()
                if new_elem is not None:
                    new_array.append(new_elem)
                current_index += 1
                next_insertion = next(insertion_iter, None)
            new_array.append(elem)
            current_index += 1
        # If any insertion positions remain, append new elements at the end.
        while next_insertion is not None:
            new_elem = _create_subarray_like(array[0]) if array else insertion_supplier()
            if new_elem is not None:
                new_array.append(new_elem)
            next_insertion = next(insertion_iter, None)
        return new_array
    return [
        _insert_indices_at_axis(subarray, axis - 1, insertion_indices, insertion_supplier)
        for subarray in array
    ]


def remove_indices_at_axis(array: list, axis: int, deletion_indices: list) -> list:
    """Recursively removes elements at specified indices along a given axis.

    The input is left untouched; new lists are built for every level that is
    visited. Indices that do not occur as a position at the target level (for
    example negative or too large values) have no effect.

    Args:
        array (list): A nested list representing a multidimensional array.
        axis (int): The axis along which elements should be removed.
        deletion_indices (list[int]): A list of indices specifying elements to remove.

    Returns:
        list: The modified array with elements removed along the specified axis.
    """
    # Build the lookup once: membership is tested for every element of every
    # sub-list at the target depth, and a list would be scanned linearly each time.
    doomed = frozenset(deletion_indices)

    def _strip(node: list, depth: int) -> list:
        if depth == 0:
            return [elem for idx, elem in enumerate(node) if idx not in doomed]
        return [_strip(child, depth - 1) for child in node]

    return _strip(array, axis)


def apply_random_replacement(
    array: list, p: float, replacement_supplier: Callable
) -> tuple[list, bool]:
    """Rebuilds a nested list, swapping non-list elements out with probability `p`.

    Nested lists are processed recursively. For every other element a random
    float is drawn; if it is smaller than `p`, the element is handed to
    `replacement_supplier` and the returned value takes its place. The input
    lists are not modified; a new nested list is built.

    Args:
        array: The nested list to process.
        p: The probability of replacing each non-list element.
        replacement_supplier: Called with the current element, returns the value
            to use instead.

    Returns:
        tuple[list, bool]: The rebuilt array and a flag that is True iff at least
        one replacement differs from the element it replaced.
    """
    result: list = []
    any_change = False

    for item in array:
        if isinstance(item, list):
            # Descend into the sub-list and propagate its change flag upwards.
            nested, nested_changed = apply_random_replacement(item, p, replacement_supplier)
            result.append(nested)
            any_change = any_change or nested_changed
            continue

        if randomness.next_float() < p:
            candidate = replacement_supplier(item)
            # A replacement that equals the old value does not count as a change.
            if candidate != item:
                any_change = True
        else:
            candidate = item
        result.append(candidate)

    return result, any_change
